# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:hydrogen
#     text_representation:
#       extension: .py
#       format_name: hydrogen
#       format_version: '1.3'
#       jupytext_version: 1.13.8
#   kernelspec:
#     display_name: Python (convex_nn_1.1.15)
#     language: python
#     name: convex_nn_1.1.15
# ---

# %%
import numpy as np
import cvxpy as cp
from tqdm import tqdm
from scipy.sparse.linalg import LinearOperator, lsqr, lsmr
from sklearn.kernel_ridge import KernelRidge

from functools import partial

# %%
import matplotlib.pyplot as plt

# %%
# Global plotting defaults: large figures, readable fonts, grid on every axis.
plt.rcParams.update(
    {
        "figure.figsize": (12, 9),
        "font.size": 18,
        "axes.grid": True,
    }
)

# %%
# Create the output directory for saved figures. Unlike the IPython `!mkdir`
# shell escape, this is valid plain Python (the paired .py script runs too) and
# is idempotent: re-running the cell does not complain that figs/ exists.
import os

os.makedirs("figs", exist_ok=True)


# %%
def get_unique_masks(
    X,
    G=None,
    max_neurons=1000,
    bias=True,
    kink_eps=None,
):
    """get_unique_masks.

    Compute the activation masks ``1[X @ g >= 0]`` produced by a set of gate
    vectors ``g``, keeping only one gate per distinct mask.

    :param X: data matrix, shape (n, d)
    :param G: gates to generate masks, shape (num_gates, d)
        if this is None, random gates will be sampled
    :param max_neurons: number of random gates to generate (only used when G is
        None and we are not in the d == 2 / bias case)
    :param bias: whether the last dimension of the data is all 1s (aka bias)
    :param kink_eps: only used when G is None, d == 2 and bias. If given, gate
        thresholds are drawn within ``kink_eps`` of the data points (and then
        randomly perturbed, see below) instead of uniformly between
        consecutive points.

    Returns:
        D: unique masks, shape (p, n). ``np.unique`` returns them sorted
            lexicographically, not in the original gate order.
        G: gate vectors that generate the masks in D, shape (p, d)
    """
    n, d = X.shape
    if G is None:
        if d == 2 and bias:
            # data is effectively 1-d (x plus a constant bias column), so we can
            # enumerate all 2*n gates: one threshold t per data point, in both
            # orientations (x - t >= 0 and -(x - t) >= 0).
            # This path requires X to be sorted by its first column.
            assert (X == X[np.argsort(X[:, 0])]).all()
            if kink_eps is not None:
                # thresholds just above the data points: t_j in [x_j, x_j + kink_eps)
                # for the first n-1 points, plus x_n - kink_eps (prepended); shape (1, n)
                G = np.random.rand(1, n - 1) * (kink_eps) + X[:-1, 0]
                G = np.append(X[-1, 0] - kink_eps * np.ones((1, 1)), G, axis=1)
                # G2 = np.random.rand(1, n - 1) * (-kink_eps) + X[:-1, 0]
                # G2 = np.append(X[-1, 0] + kink_eps * np.ones((1, 1)), G2, axis=1)
                # G = np.append(G, -G2, axis=1)
                # G = np.append(np.ones((1, 2 * n)), G, axis=0)
                # gate columns (1, -t_j), i.e. active when x >= t_j; shape (2, n)
                G = np.append(np.ones((1, n)), -G, axis=0)
                # add the opposite orientation of every gate; shape (2, 2n)
                G = np.append(G, -G, axis=1)
                # NOTE(review): this adds the same U[0, 1) offset to BOTH rows of each
                # gate column (slope and offset alike), so the kinks are no longer
                # exactly at t_j. Confirm this perturbation is intended.
                G += np.random.rand(2 * n)
            else:
                # sample thresholds uniformly between consecutive data points; the
                # prepended first threshold is drawn from [x_1 - 1, x_1); shape (1, n)
                G = np.random.rand(1, n - 1) * (X[1:, 0] - X[:-1, 0]) + X[:-1, 0]
                G = np.append(np.random.rand(1, 1) + X[0, 0] - 1, G, axis=1)
                # gate columns (1, -t_j), then both orientations; shape (2, 2n)
                G = np.append(np.ones((1, n)), -G, axis=0)
                G = np.append(G, -G, axis=1)
        else:
            # generic case: i.i.d. standard-normal gates, shape (d, max_neurons)
            G = np.random.randn(d, max_neurons)
    else:
        # caller passes gates as rows; work with them as columns, shape (d, num_gates)
        G = G.T
    # activation pattern of every gate on every sample, shape (n, num_gates)
    D = (X @ G >= 0).astype(np.int32)
    # plt.imshow(D)
    # plt.show()
    # keep one gate per distinct mask column (index of first occurrence);
    # np.unique also sorts the columns lexicographically
    D, ii = np.unique(D, return_index=True, axis=1)
    G = G[:, ii]
    return D.T, G.T


# %%
def get_data(n, d, max_neurons=1000, bias=True):
    """Sample a standard-Gaussian design matrix together with its gate masks.

    :param n: number of samples
    :param d: total data dimension (including the bias column when ``bias``)
    :param max_neurons: number of random gates drawn when masks are sampled
    :param bias: if True, the last column of X is a constant column of ones

    Returns:
        X: data matrix, shape (n, d)
        D: unique binary masks, shape (p, n)
        G: gates that generate D, shape (p, d)
    """
    n_random_cols = d - 1 if bias else d
    X = np.random.randn(n, n_random_cols)
    if bias:
        X = np.hstack((X, np.ones((n, 1))))
        if d == 2:
            # Effectively 1-D data: get_unique_masks expects it sorted by x.
            X = X[np.argsort(X[:, 0], axis=0)]
    D, G = get_unique_masks(X, max_neurons=max_neurons, bias=bias)
    return X, D, G


# %%
class MaskedDataOp(LinearOperator):
    """Matrix-free linear operator for the masked (gated) data matrix.

    Represents the ``(n, p*d)`` matrix ``A = [diag(D[0]) X, ..., diag(D[p-1]) X]``.
    With ``v`` reshaped to ``(p, d)``, ``A @ v = sum_i D[i] * (X @ v[i])``.

    :param X: data matrix, shape (n, d)
    :param D: mask (possibly rescaled) matrix, shape (p, n)
    """

    def __init__(self, X, D):
        self.X = X
        self.D = D
        self.n, self.d = self.X.shape
        self.p = self.D.shape[0]
        # Let the LinearOperator base class validate and store shape/dtype.
        super().__init__(dtype=np.float64, shape=(self.n, self.p * self.d))

    def _matvec(self, v):
        # v may arrive as (p*d,) or (p*d, 1); reshaping handles both.
        v = np.reshape(v, (self.p, self.d))
        return np.einsum("nd, pd, pn -> n", self.X, v, self.D)

    def _rmatvec(self, v):
        # LinearOperator.rmatvec accepts both (n,) and (n, 1) inputs. Flatten so
        # the 1-D einsum subscript "n" matches in either case.
        v = np.ravel(v)
        return np.einsum("pn, n, nd -> pd", self.D, v, self.X).reshape(-1)


# %%
def solve_l2_grelu(
    X,
    D,
    y,
    weights,
    L=0.02,
    solver=lsqr,
    **solver_kwargs
):
    """Solve the weighted ridge (l2-regularized) gated-ReLU problem matrix-free.

    Minimizes ``||sum_i D[i] * (X @ w_i) - y||^2 + L * sum_i weights[i] * ||w_i||^2``.
    The substitution ``u_i = sqrt(weights[i]) * w_i`` turns this into a standard
    damped least-squares problem in ``u`` with damping ``sqrt(L)``. Such solvers
    minimize ``||A u - b||^2 + damp^2 * ||u||^2``.

    X: data matrix of shape (n, d)
    D: masks of shape (p, n)
    y: targets of shape (n,)
    weights: regularization weights for each mask, shape (p,) or (p, 1); must be
        positive, since they are square-rooted and divided by
    L: scalar which is global l2-regularization weight. Each mask is therefore
        effectively regularized with weight L*weights[i].
    solver: damped least-squares solver with the scipy ``lsqr``/``lsmr``
        calling convention (solution vector returned first)
    solver_kwargs: extra keyword arguments forwarded to ``solver``, e.g.
        ``atol``/``btol`` and ``iter_lim`` (lsqr) or ``maxiter`` (lsmr)

    Returns:
        sol: solution W of shape (p, d)
    """
    p = D.shape[0]
    d = X.shape[1]
    scale = np.sqrt(weights).reshape(-1, 1)
    # Column block i of the operator is diag(D[i]) X / sqrt(weights[i]).
    A = MaskedDataOp(X, D / scale)
    u_opt = solver(A=A, b=y, damp=np.sqrt(L), **solver_kwargs)[0]
    # Undo the substitution to recover w_i = u_i / sqrt(weights[i]).
    return u_opt.reshape(p, d) / scale


# %%
def solve_l2_grelu_cvxpy(
    X,
    D,
    y,
    weights,
    L=0.02,
):
    """Solve the weighted ridge gated-ReLU problem with CVXPY.

    Minimizes ``||sum_i D[i] * (X @ w_i) - y||^2 + L * sum_i weights[i] * ||w_i||^2``.

    X: data matrix of shape (n, d)
    D: masks of shape (p, n)
    y: targets of shape (n,)
    weights: regularization weights for each mask, p entries (any shape that
        flattens to (p,), e.g. (p, 1))
    L: scalar global l2-regularization weight, so mask i is effectively
        regularized with weight L*weights[i]

    Returns:
        the value of the (p, d) weight variable after ``Problem.solve()``
    """
    num_masks, dim = D.shape[0], X.shape[1]
    W = cp.Variable((num_masks, dim))
    # Gated prediction: mask i selects the samples that neuron i acts on.
    prediction = sum(cp.multiply(mask, (X @ w_row)) for mask, w_row in zip(D, W))
    data_fit = cp.norm2(prediction - y) ** 2
    penalty = L * cp.sum(cp.multiply(cp.norm2(W, axis=1) ** 2, weights.reshape(-1)))
    problem = cp.Problem(cp.Minimize(data_fit + penalty))
    problem.solve()
    return W.value


# %%
def gated_relu_model(x, G, W):
    """Evaluate a gated-ReLU network on a batch of inputs.

    Neuron i contributes ``W[i] @ x_j`` to sample j exactly when its gate
    ``G[i] @ x_j`` is non-negative. Contributions are summed over neurons.

    :param x: inputs, shape (n, d)
    :param G: gate vectors, shape (p, d)
    :param W: neuron weights, shape (p, d)

    Returns:
        predictions, shape (n,)
    """
    gate_open = G @ x.T >= 0
    linear_part = W @ x.T
    return np.sum(gate_open * linear_part, axis=0)


# %%
def lasso_objective(X, D, w, y, L):
    """Group-lasso objective of a gated-ReLU solution.

    Returns ``||sum_i D[i] * (X @ w[i]) - y||_2^2 + L * sum_i ||w[i]||_2``.

    :param X: data matrix, shape (n, d)
    :param D: masks, shape (p, n)
    :param w: weights, shape (p, d)
    :param y: targets, shape (n,)
    :param L: group-lasso regularization strength
    """
    pred = 0
    for mask, w_i in zip(D, w):
        pred = pred + mask * (X @ w_i)
    squared_residual = np.linalg.norm(pred - y) ** 2
    group_norms = np.linalg.norm(w, axis=1).sum()
    return squared_residual + L * group_norms


# %%
def model_loss(model, X, y):
    """Mean squared error of the callable ``model`` on the data ``(X, y)``."""
    residual = model(X) - y
    return np.mean(residual ** 2)


# %% [markdown]
# Solve the unconstrained group-lasso problem corresponding to the gated ReLU model

# %%
def train_l1_gated_relu(X, G, y, L=0.02):
    """Fit a gated-ReLU model by solving the group-lasso problem with CVXPY.

    Minimizes ``||sum_i D[i] * (X @ w_i) - y||_2^2 + L * sum_i ||w_i||_2`` over
    the unique masks D generated by the gates G, and prints the optimal value.

    :param X: data matrix, shape (n, d)
    :param G: gate vectors, shape (num_gates, d); gates producing duplicate
        masks are dropped
    :param y: targets, shape (n,)
    :param L: group-lasso regularization strength

    Returns:
        model: callable mapping inputs of shape (m, d) to predictions (m,)
        w_opt: optimal weights, shape (p, d)

    Raises:
        RuntimeError: if CVXPY finishes without a solution value (e.g. an
            infeasible/unbounded or otherwise failed status).
    """
    D, G = get_unique_masks(X, G)
    p, d = G.shape
    w = cp.Variable((p, d))
    pred = sum(cp.multiply(di, (X @ wi)) for di, wi in zip(D, w))
    objective = cp.norm2(pred - y) ** 2 + L * cp.mixed_norm(w, 2, 1)
    prob = cp.Problem(cp.Minimize(objective))
    print(prob.solve())
    if w.value is None:
        # Otherwise the .copy() below fails with an opaque AttributeError.
        raise RuntimeError(f"Group-lasso solve failed with status {prob.status!r}")
    w_opt = w.value.copy()
    model = partial(gated_relu_model, G=G, W=w_opt)
    return model, w_opt


# %% [markdown]
# Solve the same problem with iteratively reweighted least squares (IRLS)

# %%
def train_l2_gated_relu(X, G, y, weights, L=0.02):
    """Fit a gated-ReLU model with a per-mask weighted l2 (ridge) penalty.

    :param X: data matrix, shape (n, d)
    :param G: gate vectors, shape (num_gates, d); deduplicated by mask
    :param y: targets, shape (n,)
    :param weights: per-mask regularization weights (p entries)
    :param L: global l2-regularization strength

    Returns:
        model: callable mapping inputs of shape (m, d) to predictions (m,)
        W_opt: fitted weights, shape (p, d)
    """
    masks, gates = get_unique_masks(X, G)
    # Matrix-free alternative: solve_l2_grelu(X, masks, y, weights, L=L)
    W_opt = solve_l2_grelu_cvxpy(X, masks, y, weights, L=L)
    return partial(gated_relu_model, G=gates, W=W_opt), W_opt


# %%
def train_l1_gated_relu_reweighted(
    X,
    G,
    y,
    L=0.02,
    maxiters=100,
    eps=1e-10,
    init_weights=None,
):
    """Approximate the group-lasso gated-ReLU fit by IRLS.

    Each iteration solves the weighted ridge problem (``train_l2_gated_relu``)
    with the current per-mask weights, then updates them to
    ``1 / sqrt(||w_i||^2 + eps)``. The training MSE is shown in the tqdm bar.

    :param X: data matrix, shape (n, d)
    :param G: gate vectors, shape (num_gates, d); deduplicated by mask
    :param y: targets, shape (n,)
    :param L: global regularization strength
    :param maxiters: number of IRLS iterations (must be >= 1)
    :param eps: smoothing term in the weight update, avoids division by zero
    :param init_weights: initial per-mask weights (p entries); ones if None

    Returns:
        model: gated-ReLU model built from the last iterate
        sols: list of all ``maxiters`` iterates, each of shape (p, d)

    Raises:
        ValueError: if ``maxiters < 1``, since there would be no iterate to return.
    """
    if maxiters < 1:
        raise ValueError(f"maxiters must be >= 1, got {maxiters}")
    sols = []
    # Only the deduplicated gates are needed here; masks are recomputed inside
    # train_l2_gated_relu.
    _, G = get_unique_masks(X, G)
    t = tqdm(range(maxiters), desc="Iteratively reweighted least squares")
    p = G.shape[0]
    weights = np.ones((p, 1)) if init_weights is None else init_weights.copy()
    for _ in t:
        modelk, Wk = train_l2_gated_relu(X, G, y, weights, L=L)
        sols.append(Wk.copy())
        # IRLS reweighting: a large group norm gets a small l2 weight next round.
        weights = 1 / np.sqrt(np.linalg.norm(Wk, axis=1, keepdims=True) ** 2 + eps)
        t.set_postfix(train_loss=model_loss(modelk, X, y))
    model = partial(gated_relu_model, G=G, W=Wk)
    return model, sols


# %%
def train_l1_squared_gated_relu(X, G, y, L=0.02):
    """Fit a gated-ReLU model with a squared group-lasso penalty via CVXPY.

    Minimizes ``||sum_i D[i] * (X @ w_i) - y||_2^2 + L * (sum_i ||w_i||_2)^2``
    over the unique masks D generated by the gates G.

    :param X: data matrix, shape (n, d)
    :param G: gate vectors, shape (num_gates, d); deduplicated by mask
    :param y: targets, shape (n,)
    :param L: regularization strength

    Returns:
        model: callable mapping inputs of shape (m, d) to predictions (m,)
        w_opt: optimal weights, shape (p, d)

    Raises:
        RuntimeError: if CVXPY finishes without a solution value.
    """
    D, G = get_unique_masks(X, G)
    p, d = G.shape
    w = cp.Variable((p, d))
    pred = sum(cp.multiply(di, (X @ wi)) for di, wi in zip(D, w))
    objective = cp.norm2(pred - y) ** 2 + L * cp.mixed_norm(w, 2, 1) ** 2
    prob = cp.Problem(cp.Minimize(objective))
    prob.solve()
    if w.value is None:
        # Otherwise the .copy() below fails with an opaque AttributeError.
        raise RuntimeError(
            f"Squared group-lasso solve failed with status {prob.status!r}"
        )
    w_opt = w.value.copy()
    model = partial(gated_relu_model, G=G, W=w_opt)
    return model, w_opt


# %%
def ntk(x1, x2):
    """ntk.

    Pairwise kernel ``K[i, j] = (pi - theta_ij) * <x1_i, x2_j> / (2 * pi)``,
    where ``theta_ij`` is the angle between ``x1_i`` and ``x2_j``.

    :param x1: shape (n1, d)
    :param x2: shape (n2, d)

    Returns:
        K: All pairwise ntk evaluations, shape (n1, n2). Entries involving an
            all-zero row are 0 (their inner product is 0) rather than NaN.
    """
    norms1 = np.linalg.norm(x1, axis=1, keepdims=True)
    norms2 = np.linalg.norm(x2, axis=1, keepdims=True)
    # Avoid 0/0 for zero rows: the "normalized" row stays zero, so its cosine is
    # 0 and the entry vanishes anyway because <x1_i, x2_j> = 0.
    normed_x1 = x1 / np.where(norms1 > 0, norms1, 1.0)
    normed_x2 = x2 / np.where(norms2 > 0, norms2, 1.0)
    # Clip guards arccos against round-off slightly outside [-1, 1].
    cosines = np.clip(normed_x1 @ normed_x2.T, -1, 1)
    pairwise_angles = np.arccos(cosines)
    K = (np.pi - pairwise_angles) * (x1 @ x2.T) / (2 * np.pi)
    return K


# %%
def predict_with_ntk(test_X, krr, train_X):
    """Predict with a kernel ridge model fitted on a precomputed NTK Gram matrix.

    :param test_X: query inputs, shape (m, d)
    :param krr: fitted estimator exposing ``predict`` on precomputed kernels
    :param train_X: the training inputs the model was fitted on, shape (n, d)
    """
    return krr.predict(ntk(test_X, train_X))


# %%
def train_ntk_with_ridge(X, y, L=0.02):
    """Fit kernel ridge regression with the NTK as a precomputed kernel.

    :param X: training inputs, shape (n, d)
    :param y: targets, shape (n,)
    :param L: ridge strength (``alpha`` of KernelRidge)

    Returns:
        callable mapping query inputs (m, d) to predictions
    """
    gram = ntk(X, X)
    regressor = KernelRidge(alpha=L, kernel="precomputed")
    regressor.fit(gram, y)
    return partial(predict_with_ntk, krr=regressor, train_X=X)


# %% [markdown]
# Compute the per-mask regularization weights induced by the NTK (inverse cone probabilities)

# %%
def get_cone_probabilities(
    X,
    D,
    mc_samples=int(1e5),
    mc_repeats=10,
):
    """Probability that a random gate produces each mask.

    For mask ``D[i]`` the compatible gates form the cone
    ``{h : (2*D[i]-1) * (X @ h) >= 0}``.

    - ``d == 2``: closed form from the widest pairwise angle between the
      sign-flipped, normalized rows, ``(pi - max_angle) / (2*pi)``. ``stds``
      is returned as all zeros.
    - otherwise: Monte Carlo with ``mc_repeats`` batches of ``mc_samples``
      standard-normal gates (global NumPy RNG). ``probs`` is the mean of the
      per-batch hit rates and ``stds`` their standard deviation.

    :param X: data matrix, shape (n, d)
    :param D: masks, shape (p, n)

    Returns:
        probs, stds: arrays of shape (p,)
    """
    n, d = X.shape
    p = D.shape[0]

    probs = np.zeros(p)
    stds = np.zeros(p)
    desc = "Calculating cone probabilities"

    if d == 2:
        for i in tqdm(range(p), desc=desc):
            signs = (2 * D[i] - 1).reshape(n, 1)
            rows = signs * X
            rows /= np.linalg.norm(rows, axis=1, keepdims=True)
            cosines = np.clip(rows @ rows.T, a_min=-1, a_max=1)
            widest_angle = np.arccos(cosines).max()
            probs[i] = (np.pi - widest_angle) / (2 * np.pi)
        return probs, stds

    gates = np.random.randn(d, mc_samples, mc_repeats)
    for i in tqdm(range(p), desc=desc):
        rows = (2 * D[i] - 1).reshape(n, 1) * X
        # A gate is inside the cone when it is strictly positive on every row.
        inside = (np.einsum("ij, jkl->ikl", rows, gates) > 0).all(axis=0)
        batch_rates = inside.sum(axis=0) / mc_samples
        probs[i] = batch_rates.mean()
        stds[i] = batch_rates.std()

    return probs, stds


# %% [markdown]
# ---

# %%
# Plot colors, one per method (matching the model order used below).
las_color = "blue"  # group lasso
rew_color = "red"  # reweighted / IRLS
ntk_color = "green"  # NTK weights
krr_color = "purple"  # KRR with NTK

# Line widths (main curves, IRLS curves) and marker style.
lw, irlslw = 5, 3
ls = "-o"

# Whether figures are written to figs/ in addition to being shown.
SAVE = True

# %% [markdown]
# # 1-D example

# %%


def relu(x):
    """Elementwise rectified linear unit, ``max(0, x)``."""
    rectified = np.maximum(0, x)
    return rectified


def drelu(x):
    """ReLU derivative indicator: True where ``x >= 0`` (value 1 chosen at 0)."""
    active = x >= 0
    return active


# %%
# Five 1-D training points, each augmented with a constant bias feature of 1.
X = np.array([[-2, 1], [-1, 1], [0, 1], [1, 1], [2, 1.0]])
y = np.array([-1, 1, 1, 1, -1])
n, d = X.shape

# Axis limits for the 1-D plots.
oned_xlims = (-2.5, 2.5)
oned_ylims = (-1.5, 1.5)
xlims, ylims = oned_xlims, oned_ylims

# Experiment hyper-parameters.
eps = 1e-8  # IRLS smoothing term
repeats = 10  # number of randomly initialized IRLS runs
L = 1e-4  # regularization strength
rew_iters = 10  # IRLS iterations per run

# Hand-built gates as columns, shape (2, 2n): row 0 is [x_1..x_n, -x_1..-x_n]
# (coefficient on x), row 1 is [-1]*n + [1]*n (coefficient on the bias).
G1 = np.concatenate((X[:, 0], -X[:, 0]), axis=0).reshape(1, -1)
G2 = np.concatenate((-np.ones((n, 1)), np.ones((n, 1))), axis=0).reshape(1, -1)
G = np.vstack((G1, G2))
dmat = (X @ G >= 0) * 1.0  # training masks, shape (n, 2n)

# Test inputs: 100 evenly spaced points on [-2, 2] plus x = -1, sorted.
test = np.sort(np.append(np.linspace(-2, 2, 100), -1)).reshape(-1, 1)
Xtest = np.hstack((test, np.ones_like(test)))
dmat_test = (Xtest @ G >= 0) * 1.0

# %% [markdown]
# ## Gated ReLU with cone decomp

# %%
# Convex gated-ReLU program: fit y with the difference of two gated models
# (Uopt1 minus Uopt2) under a group-lasso penalty on the columns of each.
m1 = dmat.shape[1]
Uopt1 = cp.Variable((d, m1))
Uopt2 = cp.Variable((d, m1))

# Predictions of the two gated models, shape (n,).
yopt1 = cp.sum(cp.multiply(dmat, X @ Uopt1), axis=1)
yopt2 = cp.sum(cp.multiply(dmat, X @ Uopt2), axis=1)

# Residual y - (yopt1 - yopt2), mean-squared, plus the group penalty.
cost = cp.sum_squares(y - yopt1 + yopt2) / n + L * (
    cp.mixed_norm(Uopt1.T, 2, 1) + cp.mixed_norm(Uopt2.T, 2, 1)
)
constraints = []
prob = cp.Problem(cp.Minimize(cost), constraints)
prob.solve(verbose=False)

cvx_opt = prob.value
print("Gated objective: ", cvx_opt)
W1 = Uopt1.value
W2 = Uopt2.value


############## Cone decomposition ##############
# For each mask, find two weight vectors whose sign pattern on X matches the
# mask (so relu(X @ w) = mask * (X @ w)) and whose difference reproduces the
# gated output on the selected samples. Zero objective: a feasibility problem.
W1c = np.zeros_like(W1)
W2c = np.zeros_like(W2)

for i in range(m1):
    wc1 = cp.Variable((d, 1))
    wc2 = cp.Variable((d, 1))

    cost = 0

    dmatre = dmat[:, i]
    # +1 on samples the mask activates, -1 elsewhere.
    sign_mat = np.diag(2 * dmatre - np.ones((n,)))
    ycone = (dmatre * (X @ W1[:, i]) - dmatre * (X @ W2[:, i])).reshape(-1, 1)

    constraints = []
    constraints += [sign_mat @ (X @ wc1) >= 0]
    constraints += [sign_mat @ (X @ wc2) >= 0]
    constraints += [np.diag(dmatre) @ ((X @ wc1) - (X @ wc2)) == ycone]
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve(solver=cp.MOSEK, verbose=False)

    W1c[:, i] = wc1.value.reshape(-1)
    W2c[:, i] = wc2.value.reshape(-1)

ytest_cone = np.sum(relu(Xtest @ W1c) - relu(Xtest @ W2c), axis=1)

# %% [markdown]
# ## GD on two layer ReLU net

# %%
######### GD Non-convex #########
# Full-batch gradient descent with weight decay on a two-layer ReLU network,
# minimizing ||relu(X W1) W2 - y||^2 / (2n) + (beta/2)(||W1||_F^2 + ||W2||_F^2).

sigma = 0.01  # initialization scale
mu = 0.01  # step size
mbp = 200  # hidden width
epoch = 200000  # number of GD steps

beta_bp = L
obj_bp = []

seed = 42
np.random.seed(seed=seed)
W1bp = sigma * np.random.randn(d, mbp)
W2bp = sigma * np.random.randn(mbp, 1)

# Full batch: the data never change, so build them once outside the loop.
Xgd = X
ygd = y.reshape(-1, 1)
for i_epoch in range(epoch):
    hidden = Xgd @ W1bp
    yest = relu(hidden) @ W2bp
    resid = yest - ygd  # shape (n, 1)
    obj_bp.append(
        np.linalg.norm(resid, "fro") ** 2 / (2 * n)
        + (beta_bp / 2)
        * (np.linalg.norm(W1bp, "fro") ** 2 + np.linalg.norm(W2bp, "fro") ** 2)
    )
    gradW2 = relu(hidden).T @ resid / n
    gradW1 = Xgd.T @ (drelu(hidden) * (resid @ W2bp.reshape(1, -1))) / n

    # Weight decay folded into the step: W <- (1 - mu*beta) W - mu * grad.
    W1bp = (1 - mu * beta_bp) * W1bp - mu * gradW1
    W2bp = (1 - mu * beta_bp) * W2bp - mu * gradW2


print(f"Final objective for seed {seed}:", obj_bp[-1])
ytest_sgd = relu(Xtest @ W1bp) @ W2bp

plt.semilogy()
plt.plot(obj_bp)
plt.title("Train loss for sgd")
plt.show()

# %% [markdown]
# ## Kernel Ridge Regression with NTK

# %%
# Kernel ridge regression with the NTK, evaluated on the test inputs.
model = train_ntk_with_ridge(X=X, y=y, L=L)
ytest_ntk = model(test_X=Xtest)

# %% [markdown]
# ## IRLS, Group Lasso, and NTK weighted program

# %%
# Hand-built 1-D gates, transposed so each row is one gate of shape (d,), as
# get_unique_masks expects: (x_j, -1) and (-x_j, 1) for every data point x_j.
G1 = np.concatenate((X[:, 0], -X[:, 0]), axis=0).reshape(1, -1)
G2 = np.concatenate((-np.ones((n, 1)), np.ones((n, 1))), axis=0).reshape(1, -1)
G = np.vstack((G1, G2)).T

# %%
n, d = X.shape
# Deduplicate the hand-built gates into the unique masks they induce.
D, G = get_unique_masks(X, G)
p, d = G.shape
print(p, d)

# Per-mask regularization weights induced by the NTK: inverse cone probabilities.
probs = get_cone_probabilities(X, D)[0]
print(f"Sum of cone probabilities = {probs.sum()}")
ntk_weights = 1 / probs

# Group-lasso optimum (CVXPY).
las_model, las_sol = train_l1_gated_relu(X, G, y, L=L)
# IRLS initialized with the NTK weights.
rew_model, rew_sols = train_l1_gated_relu_reweighted(
    X, G, y, L=L, eps=eps, init_weights=ntk_weights, maxiters=rew_iters
)
# Single weighted-ridge solve with the NTK weights, compared against KRR with
# the NTK via the relative gap of their train predictions.
ntk_model, ntk_sol = train_l2_gated_relu(X, G, y, weights=ntk_weights, L=L)
krr_model = train_ntk_with_ridge(X, y, L=L)
krr_train_pred = krr_model(X)
ntk_krr_err = np.linalg.norm(ntk_model(X) - krr_train_pred) / np.linalg.norm(
    krr_train_pred
)
print(f"Error between KRR and NTK on train data: {ntk_krr_err}")

# IRLS from `repeats` random initial weights (global NumPy RNG), to compare
# against the NTK initialization.
rand_rew_sols = []
for _rep in range(repeats):
    rand_init_weights = np.random.rand(*ntk_weights.shape) + eps
    rand_rew_sols.append(
        train_l1_gated_relu_reweighted(
            X,
            G,
            y,
            L=L,
            eps=eps,
            maxiters=rew_iters,
            init_weights=rand_init_weights,
        )[1]
    )

# Models, final solutions, display names and colors, all in the same order.
all_models = [las_model, rew_model, ntk_model, krr_model]
all_sols = [las_sol, rew_sols[-1], ntk_sol]
model_names = ["Group lasso", "Reweighted", "NTK weights", "KRR with NTK"]
colors = [las_color, rew_color, ntk_color, krr_color]

# Dense grid over the plot's x-range, with the bias column appended.
sample_n = 1000
x = np.linspace(xlims[0], xlims[1], sample_n)[:, None]
sample_X = np.hstack((x, np.ones((sample_n, 1))))
plt.subplots(1, 2, figsize=(18, 6))

# Left panel: group-lasso objective of each IRLS iterate, compared with the
# optimum and with the NTK-weighted solution.
plt.subplot(1, 2, 1)
plt.semilogy()
plt.title(f"Group lasso objective, $\\lambda={L}$")
rew_lobjs = [lasso_objective(X, D, sol, y, L) for sol in rew_sols]
rand_rew_lobjs = np.array(
    [[lasso_objective(X, D, sol, y, L) for sol in run] for run in rand_rew_sols]
)
irls_style = dict(color="black", alpha=0.5, linewidth=irlslw)
# Only the first random run carries a label, so the legend has a single entry.
plt.plot(rand_rew_lobjs[1:].T, **irls_style)
plt.plot(rand_rew_lobjs[0], label="IRLS with random inits", **irls_style)
plt.plot(rew_lobjs, ls, label="IRLS with NTK init", color=rew_color, linewidth=lw)
las_lobj = lasso_objective(X, D, las_sol, y, L)
plt.axhline(las_lobj, linestyle="--", label="Optimal", color=las_color, linewidth=lw)
ntk_lobj = lasso_objective(X, D, ntk_sol, y, L)
plt.axhline(
    ntk_lobj, linestyle="--", label="NTK weights", color=ntk_color, linewidth=lw
)
plt.legend()
plt.xlabel("Iteration $k$")

# Right panel: functions learned by each method, evaluated on the test inputs.
plt.subplot(1, 2, 2)
plt.title("Learned functions")
ytest_c, ytest_s, ytest_n = ytest_cone, ytest_sgd, ytest_ntk
# Plot against the inputs the predictions were actually computed on. Xtest, as
# built in the 1-D setup cell, is not an evenly spaced 101-point grid (it is 100
# points on [-2, 2] plus x = -1). Using np.linspace(-2, 2, 101) here would
# shift every curve horizontally.
xtest = Xtest[:, 0]
plt.plot(
    xtest,
    ytest_c,
    label="Group Lasso",
    color=las_color,
    lw=lw,
    alpha=0.7,
)
plt.plot(
    xtest,
    ytest_s,
    "--",
    label="Gradient Descent",
    color="black",
    lw=lw,
    alpha=0.99,
)
plt.plot(
    xtest,
    ytest_n,
    label="KRR with NTK",
    color="green",
    lw=lw,
    alpha=0.7,
)
plt.plot(
    X[:, 0],
    y,
    marker="o",
    markersize=10,
    linewidth=0,
    alpha=0.95,
    label="Training data",
    color="red",
)
legend_artist = plt.legend(bbox_to_anchor=[1.0, 1.0])
plt.ylim(ylims)
plt.xlim(xlims)
prefix = "one_d_example"
if SAVE:
    # Include the outside-placed legend in the tight bounding box.
    plt.savefig(
        f"figs/{prefix}_L={L}.png",
        bbox_inches="tight",
        bbox_extra_artists=[legend_artist],
    )
plt.show()
print("Lasso objectives:")
print(f"{model_names[0]}=\t{las_lobj}")
print(f"{model_names[1]}=\t{rew_lobjs[-1]}")
print(f"{model_names[2]}=\t{ntk_lobj}")


# %% [markdown]
# # Student teacher

# %%
def run_st_models_and_plot(
    X,
    y,
    Ls=(0.02,),
    rew_iters=30,
    eps=1e-10,
    repeats=2,
    prefix="",
):
    """Compare group lasso, IRLS and NTK-weighted fits on one dataset.

    Samples random gates for X and computes NTK weights (inverse cone
    probabilities). Then, for each regularization strength in ``Ls``, draws
    one subplot of the group-lasso objective for:
    - IRLS started from the NTK weights;
    - IRLS from ``repeats`` random initial weights;
    - the CVXPY group-lasso optimum;
    - a single NTK-weighted ridge solve.

    For each ``L`` it also prints the relative train-set gap between the
    NTK-weighted solve and KRR with the NTK.

    :param X: data matrix, shape (n, d)
    :param y: targets, shape (n,)
    :param Ls: sequence of regularization strengths, one subplot each
        (a tuple by default to avoid a shared mutable default)
    :param rew_iters: IRLS iterations per run
    :param eps: IRLS smoothing term
    :param repeats: number of randomly initialized IRLS runs per ``L``
    :param prefix: figure file name; saved to ``figs/{prefix}.png`` when SAVE
    """
    D, G = get_unique_masks(X)
    probs = get_cone_probabilities(X, D)[0]
    print(f"Sum of cone probabilities = {probs.sum()}")
    ntk_weights = 1 / probs

    num_panels = len(Ls)
    plt.subplots(1, num_panels, figsize=(18, 6))
    plt.suptitle("Group Lasso objective for Student-Teacher setting")
    for i, L in enumerate(Ls):
        # Group-lasso optimum (reference level in the plot).
        _, las_sol = train_l1_gated_relu(X, G, y, L=L)
        # IRLS started from the NTK weights.
        _, rew_sols = train_l1_gated_relu_reweighted(
            X,
            G,
            y,
            L=L,
            eps=eps,
            init_weights=ntk_weights,
            maxiters=rew_iters,
        )
        # Single weighted-ridge solve with NTK weights, compared against KRR.
        ntk_model, ntk_sol = train_l2_gated_relu(X, G, y, weights=ntk_weights, L=L)
        krr_model = train_ntk_with_ridge(X, y, L=L)
        krr_train_pred = krr_model(X)
        ntk_krr_err = np.linalg.norm(ntk_model(X) - krr_train_pred) / np.linalg.norm(
            krr_train_pred
        )
        print(f"Error between KRR and NTK on train data: {ntk_krr_err}")

        # IRLS from random initial weights (global NumPy RNG).
        rand_rew_sols = [
            train_l1_gated_relu_reweighted(
                X,
                G,
                y,
                L=L,
                eps=eps,
                maxiters=rew_iters,
                init_weights=np.random.rand(*ntk_weights.shape) + eps,
            )[1]
            for _ in range(repeats)
        ]

        plt.subplot(1, num_panels, i + 1)
        plt.semilogy()
        plt.title(f"$\\lambda={L}$")
        rew_lobjs = [lasso_objective(X, D, sol, y, L) for sol in rew_sols]
        rand_rew_lobjs = np.array(
            [[lasso_objective(X, D, sol, y, L) for sol in run] for run in rand_rew_sols]
        )
        # Only the first random run is labelled, giving one legend entry.
        plt.plot(
            rand_rew_lobjs[1:].T,
            color="black",
            alpha=0.5,
            linewidth=irlslw,
        )
        plt.plot(
            rand_rew_lobjs[0],
            color="black",
            alpha=0.5,
            label="IRLS with random inits",
            linewidth=irlslw,
        )
        plt.plot(
            rew_lobjs,
            "-",
            label="IRLS with NTK init",
            color=rew_color,
            linewidth=lw,
        )
        las_lobj = lasso_objective(X, D, las_sol, y, L)
        plt.axhline(
            las_lobj,
            label="Optimal",
            linestyle="--",
            color=las_color,
            linewidth=lw,
        )
        ntk_lobj = lasso_objective(X, D, ntk_sol, y, L)
        plt.axhline(
            ntk_lobj,
            label="NTK weights",
            linestyle="--",
            color=ntk_color,
            linewidth=lw,
        )
        # A single legend, on the last panel.
        if i == num_panels - 1:
            plt.legend()
        plt.xlabel("Iteration $k$")
    if SAVE:
        plt.savefig(f"figs/{prefix}.png")
    plt.show()


# %%
# Student-teacher data: Gaussian inputs (plus a bias column) labelled by a
# random gated-ReLU "teacher" with p_teacher neurons.
n, d = 10, 5
max_neurons = int(1e3)
p_teacher = 10
X, D, G = get_data(n, d, max_neurons)
Gteacher = np.random.randn(p_teacher, d)
Wteacher = np.random.randn(p_teacher, d)
y = gated_relu_model(X, Gteacher, Wteacher)

st_prefix = f"st_pteacher{p_teacher}_d{d}_n{n}"
st_repeats, st_rew_iters = 5, 100
print(G.shape)

# %%
# Student-teacher comparison at two regularization strengths.
run_st_models_and_plot(
    X,
    y,
    Ls=[0.001, 0.01],
    rew_iters=st_rew_iters,
    repeats=st_repeats,
    eps=eps,
    prefix=st_prefix,
)

# %%
